Qualcomm AI Engine Direct - remove prefill calibration #17805
haowhsu-quic wants to merge 1 commit into pytorch:main
Conversation
- calibrate kv text decoder only to reduce calibration time
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17805
Note: Links to docs will display an error until the docs builds have been completed. ❌ 4 New Failures as of commit 311249c with merge base 0c2ff55.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "release notes: qualcomm"
Thanks a lot! In addition to this, I noticed that SeqMSE grid searching is done sequentially instead of in parallel. Is there room to improve there?
Yes, will look into it. |
Can you share what's broken? Is it the flow that consumes torchtune in the executorch repo?
```diff
  # transpose first to decrease the runtime efforts
  k_cache.append(
-     torch.zeros(
+     torch.ones(
```
Are we initializing the kv cache with different values?
This is for checking the numerical values with a deterministic input for validation purposes. I can revert them back to zeros.
```python
if list(decode_node.users)[0].target in ptq_target:
    activation_override(decode_node, prefill_node)

# copy encoding for hybrid mode
```
Are you copying over the quantization parameters from kv mode to prefill mode?
Yes, since the current bottleneck of calibration time is using prefill mode to generate user prompts.
I think the future scenario we want to try is to generate all the special tokens for the target model and merge them with the task calibration data, have the prefill model iterate over it, and use prefill's encodings for kv. We could skip the user prompt generation with this approach.
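The encoding copy shown in the diff above can be sketched roughly as follows. This is an illustration only: `activation_override`, the annotation key name, and the toy graphs are assumptions based on the diff context, not the backend's actual implementation.

```python
import torch
from torch import fx

class TinyBlock(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

# two structurally identical graphs standing in for decode / prefill
decode_g = fx.symbolic_trace(TinyBlock())
prefill_g = fx.symbolic_trace(TinyBlock())

# assumed key name, standing in for the slot pt2e-style flows store annotations under
KEY = "quantization_annotation"

def activation_override(decode_node, prefill_node, key=KEY):
    """Share decode's already-calibrated activation annotation with prefill."""
    if key in decode_node.meta:
        prefill_node.meta[key] = decode_node.meta[key]

# pretend decode was calibrated, then copy its encodings over node-by-node
for d, p in zip(decode_g.graph.nodes, prefill_g.graph.nodes):
    if d.op == "call_function":
        d.meta[KEY] = {"scale": 0.02, "zero_point": 0}
    activation_override(d, p)
```

This way prefill never needs its own calibration pass; it just inherits the observed ranges decode already collected.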
```python
#
# however, pytorch will use different computation kernels for different
# workloads (AR1 vs ARN) which will introduce some numerical discrepancy.
#
```
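The AR1-vs-ARN remark in the comment above can be demonstrated directly: running attention for a single query position (decode-style) versus the full sequence (prefill-style) is mathematically identical for that position, but the two shapes may dispatch to different kernels, so the results only agree up to float tolerance. The shapes below are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 4, 8, 16)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# ARN (prefill-style): all 8 query positions at once
arn = F.scaled_dot_product_attention(q, k, v)
# AR1 (decode-style): only the last query position, same keys/values
ar1 = F.scaled_dot_product_attention(q[:, :, -1:], k, v)

# Same math for the last position, but not guaranteed bit-identical
print(torch.allclose(arn[:, :, -1:], ar1, atol=1e-4))
```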
What is the mechanism to make sure the encodings align correctly?
I'm worried about the accuracy too if we get rid of prefill calibration. Do you think that generating prompt + output using the fp32 model (pre-observers), as discussed in PR #17786, and running prefill + decode as before with `skip_generate`, might yield better accuracy rather than getting rid of the entire prefill calibration?
Prefill calibration ideally is not needed because decode sees all the generated tokens too, and the prefill graph and decode graph should be the same. I remember @haowhsu-quic mentioned we insert the kv cache output of prefill and connect it to the decode input to make sure those quant nodes are also calibrated. I did a comparison of quant params between prefill and decode in the past and they were very, very close. I'm trying to figure out whether this PR handles the kv cache differently than before.
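The prefill-vs-decode quant-param comparison described above amounts to walking the two calibrated graphs in lockstep and checking how close each pair of encodings is. A minimal sketch, with illustrative names and made-up example values (the real encodings come from the calibrated graphs):

```python
def compare_encodings(prefill_encodings, decode_encodings, rtol=1e-2):
    """Each arg: list of (name, scale, zero_point) in topological order.
    Returns the pairs whose scales differ beyond rtol or whose zero_points differ."""
    mismatches = []
    for (pn, ps, pz), (dn, ds, dz) in zip(prefill_encodings, decode_encodings):
        if abs(ps - ds) > rtol * max(abs(ps), abs(ds)) or pz != dz:
            mismatches.append((pn, dn))
    return mismatches

# made-up encodings that are "very, very close", as observed in the comparison
prefill = [("attn.q", 0.0200, 0), ("attn.k", 0.0310, 0)]
decode  = [("attn.q", 0.0201, 0), ("attn.k", 0.0309, 0)]
print(compare_encodings(prefill, decode))  # -> []
```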
I guess my question is that prefill "sees" previous tokens, and the attention block will take these into consideration while generating the kv-cache.
For weights, what you said makes sense as they are idempotent from the math standpoint; maybe we should just check the PPL on-device?
CC: @metascroy, @kimishpatel if you have any thoughts on this.
The mechanism is to compare the topological order of ops in both graphs, where each op at the same position should have an identical nn_module_stack. The number and type of QDQ pairs among each node's (call_function / placeholder) users are required to be identical as well.
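A rough sketch of that alignment check: walk both graphs in topological order and require the same op sequence with matching nn_module_stack metadata at each position. The helper name is illustrative, and the `symbolic_trace` demo graphs below carry no nn_module_stack (exported/pt2e graphs would), so the demo mainly exercises the topology comparison.

```python
import torch
from torch import fx

def graphs_align(graph_a, graph_b):
    """Illustrative check: same op count and same nn_module_stack per position.
    (The real check also compares the QDQ users of each node.)"""
    keep = ("call_function", "placeholder")
    nodes_a = [n for n in graph_a.nodes if n.op in keep]
    nodes_b = [n for n in graph_b.nodes if n.op in keep]
    if len(nodes_a) != len(nodes_b):
        return False
    return all(
        a.meta.get("nn_module_stack") == b.meta.get("nn_module_stack")
        for a, b in zip(nodes_a, nodes_b)
    )

class WithRelu(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

class DoubleRelu(torch.nn.Module):
    def forward(self, x):
        return torch.relu(torch.relu(x))

g1 = fx.symbolic_trace(WithRelu()).graph
g2 = fx.symbolic_trace(WithRelu()).graph   # same topology -> aligns
g3 = fx.symbolic_trace(DoubleRelu()).graph # extra op -> does not align
print(graphs_align(g1, g2), graphs_align(g1, g3))
```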
Will have a fix for the condition `prefill_ar_len == max_seq_len`, thanks for identifying this.
Ran the comparison over the weekend and observed great speedups: calibration time is cut in half, and overall time is cut by ~33% for the Qwen model I'm benchmarking. I will do a final PPL check on-device and we will be good to go. Thanks again for working on this.
I think we should also find ways to cut the per-iteration decode timing.
┌─────────────────────┬──────────┬────────┬─────────────────────┐
│ │ Baseline │ PR │ Savings │
├─────────────────────┼──────────┼────────┼─────────────────────┤
│ DECODE calibration │ 2h 59m │ 2h 49m │ ~10m (noise) │
├─────────────────────┼──────────┼────────┼─────────────────────┤
│ PREFILL calibration │ 2h 34m │ 15.5s │ 2h 34m │
├─────────────────────┼──────────┼────────┼─────────────────────┤
│ Quantization total │ 5h 33m │ 2h 50m │ 2h 43m (49% faster) │
├─────────────────────┼──────────┼────────┼─────────────────────┤
│ Compile total │ 2h 35m │ 2h 35m │ ~0 (noise) │
├─────────────────────┼──────────┼────────┼─────────────────────┤
│ End-to-end │ 8h 13m │ 5h 30m │ 2h 43m (33% faster) │
└─────────────────────┴──────────┴────────┴─────────────────────┘
@haowhsu-quic PPL looks fine.
Let's go with the change. Can you please rebase?
> The mechanism is to compare the topological order of ops in both graphs, where each op at the same position should have an identical nn_module_stack. The number and type of QDQ pairs among each node's (call_function / placeholder) users are required to be identical as well.
Is there any particular reason that we removed the custom annotation? I understand that the custom annotation caused prefill/decode to have different graphs, but what about the quantization params for the kv cache?
I think quantization params for placeholders (kv cache) will also be shared between prefill & decode. They should be the same under identical calibration data. May I learn more about your concern?
I think we're using torchtune to convert parameter naming in some llms. Since torchao has deprecated the
Summary
Total Quantization Time
Test plan
```shell
python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript
python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleMultimodalityScript
```